{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import nltk\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(S (PRP I) (VP (VBP am) (NNP Sam)) (. .))\n",
      "(PRP I)\n",
      "(VP (VBP am) (NNP Sam))\n",
      "(VBP am)\n",
      "(NNP Sam)\n",
      "(. .)\n",
      "(S (PRP I) (VP (VBP am) (NNP Sam)) (. .))\n"
     ]
    }
   ],
   "source": [
    "string = \"(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))\"\n",
    "tree = nltk.tree.Tree.fromstring(string)\n",
    "\n",
    "def load_compressed_tree(s):\n",
    "\n",
    "    def compress_tree(tree):\n",
    "        if len(tree) == 1:\n",
    "            if isinstance(tree[0], nltk.tree.Tree):\n",
    "                return compress_tree(tree[0])\n",
    "            else:\n",
    "                return tree\n",
    "        else:\n",
    "            for i, t in enumerate(tree):\n",
    "                tree[i] = compress_tree(t)\n",
    "            return tree\n",
    "\n",
    "    return compress_tree(nltk.tree.Tree.fromstring(s))\n",
    "tree = load_compressed_tree(string)\n",
    "for t in tree.subtrees():\n",
    "    print(t)\n",
    "    \n",
    "print(str(tree))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(ROOT I am Sam .)\n"
     ]
    }
   ],
   "source": [
    "print(tree.flatten())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['ROOT', 'S', 'NP', 'PRP', 'VP', 'VBP', 'NP', 'NNP', '.']\n"
     ]
    }
   ],
   "source": [
    "print(list(t.label() for t in tree.subtrees()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import json\n",
    "d = json.load(open(\"data/squad/shared_dev.json\", 'r'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "73"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(d['pos_counter'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'#': 6,\n",
       " '$': 80,\n",
       " \"''\": 1291,\n",
       " ',': 14136,\n",
       " '-LRB-': 1926,\n",
       " '-RRB-': 1925,\n",
       " '.': 9505,\n",
       " ':': 1455,\n",
       " 'ADJP': 3426,\n",
       " 'ADVP': 4936,\n",
       " 'CC': 9300,\n",
       " 'CD': 6216,\n",
       " 'CONJP': 191,\n",
       " 'DT': 26286,\n",
       " 'EX': 288,\n",
       " 'FRAG': 107,\n",
       " 'FW': 96,\n",
       " 'IN': 32564,\n",
       " 'INTJ': 12,\n",
       " 'JJ': 21452,\n",
       " 'JJR': 563,\n",
       " 'JJS': 569,\n",
       " 'LS': 7,\n",
       " 'LST': 1,\n",
       " 'MD': 1051,\n",
       " 'NAC': 19,\n",
       " 'NN': 34750,\n",
       " 'NNP': 28392,\n",
       " 'NNPS': 1400,\n",
       " 'NNS': 16716,\n",
       " 'NP': 91636,\n",
       " 'NP-TMP': 236,\n",
       " 'NX': 108,\n",
       " 'PDT': 89,\n",
       " 'POS': 1451,\n",
       " 'PP': 33278,\n",
       " 'PRN': 2085,\n",
       " 'PRP': 2320,\n",
       " 'PRP$': 1959,\n",
       " 'PRT': 450,\n",
       " 'QP': 838,\n",
       " 'RB': 7611,\n",
       " 'RBR': 301,\n",
       " 'RBS': 252,\n",
       " 'ROOT': 9587,\n",
       " 'RP': 454,\n",
       " 'RRC': 19,\n",
       " 'S': 21557,\n",
       " 'SBAR': 5009,\n",
       " 'SBARQ': 6,\n",
       " 'SINV': 135,\n",
       " 'SQ': 5,\n",
       " 'SYM': 17,\n",
       " 'TO': 5167,\n",
       " 'UCP': 143,\n",
       " 'UH': 15,\n",
       " 'VB': 4197,\n",
       " 'VBD': 8377,\n",
       " 'VBG': 3570,\n",
       " 'VBN': 7218,\n",
       " 'VBP': 2897,\n",
       " 'VBZ': 4146,\n",
       " 'VP': 33696,\n",
       " 'WDT': 1368,\n",
       " 'WHADJP': 5,\n",
       " 'WHADVP': 439,\n",
       " 'WHNP': 1927,\n",
       " 'WHPP': 153,\n",
       " 'WP': 482,\n",
       " 'WP$': 50,\n",
       " 'WRB': 442,\n",
       " 'X': 23,\n",
       " '``': 1269}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d['pos_counter']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[False False False False]\n",
      " [False  True False False]\n",
      " [False False False False]]\n",
      "[[0 2 2 0]\n",
      " [2 2 0 2]\n",
      " [2 0 0 0]]\n"
     ]
    }
   ],
   "source": [
    "from my.nltk_utils import tree2matrix, load_compressed_tree, find_max_f1_subtree, set_span\n",
    "string = \"(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))\"\n",
    "tree = load_compressed_tree(string)\n",
    "span = (1, 3)\n",
    "set_span(tree)\n",
    "subtree = find_max_f1_subtree(tree, span)\n",
    "f = lambda t: t == subtree\n",
    "g = lambda t: 1 if isinstance(t, str) else 2\n",
    "a, b = tree2matrix(tree, f, dtype='bool')\n",
    "c, d = tree2matrix(tree, g, dtype='int32')\n",
    "print(a)\n",
    "print(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
